package vip.shuai7boy.trafficTemp.areaRoadFlow;

import com.alibaba.fastjson.JSONObject;
import org.apache.spark.SparkConf;
import org.apache.spark.api.java.JavaPairRDD;
import org.apache.spark.api.java.JavaRDD;
import org.apache.spark.api.java.JavaSparkContext;
import org.apache.spark.api.java.function.Function;
import org.apache.spark.api.java.function.PairFunction;
import org.apache.spark.sql.*;
import org.apache.spark.sql.types.DataTypes;
import org.apache.spark.sql.types.StructField;
import org.apache.spark.sql.types.StructType;
import scala.Tuple2;
import vip.shuai7boy.trafficTemp.conf.ConfigurationManager;
import vip.shuai7boy.trafficTemp.constant.Constants;
import vip.shuai7boy.trafficTemp.dao.ITaskDAO;
import vip.shuai7boy.trafficTemp.dao.factory.DAOFactory;
import vip.shuai7boy.trafficTemp.domain.Task;
import vip.shuai7boy.trafficTemp.util.ParamUtils;
import vip.spark.spark.test.MockData;

import java.util.*;

/**
 * 计算每一个区域车流量最多的前3条道路。
 * （每个区域多条道路，每条道路多个卡口，每个卡口多个摄像头）
 * 根据区域划分，计算每条道路的车流量，
 * 时间，卡口ID，监控ID，车牌，拍照时间，车速，道路ID，区域ID
 * 区域1，（road1，road2，road3）
 * 区域2，（road1，road2，road3）
 * 1.区域，ROW
 * 2.区域，（【道路，车流量】）
 * 3.
 * <p>
 * 这个一个分组topN 利用Spark SQL分组topN。
 */
public class AreaTop3RoadFlowAnalyze {
    public static void main(String[] args) {
        /**
         * Decide whether the program runs locally (mock data) or on a cluster (Hive).
         */
        JavaSparkContext sc = null;
        SparkSession spark = null;
        // Primitive boolean: `if (onLocal)` unboxes anyway, so avoid the boxed type
        // (consistent with getAreaId2AreaInfoRDD below).
        boolean onLocal = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL);
        if (onLocal) {
            // Build the local Spark runtime and load mock data instead of Hive tables.
            SparkConf conf = new SparkConf().setAppName(Constants.SPARK_APP_NAME).setMaster("local");
            sc = new JavaSparkContext(conf);
            spark = SparkSession.builder().getOrCreate();
            MockData.mock(sc, spark);

        } else {

            System.out.println("++++++++++++++++++++++++++++++++++++++开启hive的支持");
            // Cluster mode: enable Hive support and raise the broadcast-join threshold
            // (~1 GB) so the small dimension tables are broadcast.
            spark = SparkSession.builder().appName(Constants.SPARK_APP_NAME).config("spark.sql" +
                    ".autoBroadcastJoinThreshold", "1048576000").enableHiveSupport().getOrCreate();
            sc = new JavaSparkContext(spark.sparkContext());
            spark.sql("use traffic");

        }
        // Register the custom UDFs/UDAF used by the SQL statements below.
        spark.udf().register("concat_String_string", new ConcatStringStringUDF(), DataTypes.StringType);
        spark.udf().register("random_prefix", new RandomPrefixUDF(), DataTypes.StringType);
        spark.udf().register("remove_random_prefix", new RemoveRandomPrefixUDF(), DataTypes.StringType);
        spark.udf().register("group_concat_distinct", new GroupConcatDistinctUDAF());
        // Resolve the task id passed on the command line and load its parameters.
        ITaskDAO taskDAO = DAOFactory.getTaskDAO();
        long taskid = ParamUtils.getTaskIdFromArgs(args, Constants.SPARK_LOCAL_TASKID_TOPN_MONITOR_FLOW);
        Task task = taskDAO.findTaskById(taskid);
        if (task == null) {
            // Unknown task id: nothing to do.
            return;
        }
        JSONObject taskParam = JSONObject.parseObject(task.getTaskParams());
        /**
         * Step 1: vehicle detail rows within the requested date range, keyed as
         * (areaId, row).
         */
        JavaPairRDD<String, Row> areaId2DetailInfos = getInfosByDateRDD(spark, taskParam);
        /**
         * Step 2: area dimension data from MySQL, keyed as (areaId, areaName).
         */
        JavaPairRDD<String, String> areaId2AreaInfoRDD = getAreaId2AreaInfoRDD(spark);
        /**
         * Step 3: join the two RDDs to attach the area name and register the
         * base temp view tmp_car_flow_basic.
         */
        generateTempRoadFlowBasicTable(spark, areaId2DetailInfos, areaId2AreaInfoRDD);

        /**
         * Step 4: aggregate car counts per (area, road) into the temp view
         * tmp_area_road_flow_count.
         */
        generateTempAreaRoadFlowTable(spark);

        /**
         * Step 5: compute the top-3 roads per area and persist the result.
         */
        getAreaTop3RoadFolwRDD(spark);

        System.out.println("++++++++++++++++++full complete+++++++++++++");
        sc.close();
        spark.close();

    }


    /**
     * Step 5: rank roads inside each area by car count with a window function,
     * keep the top 3, grade each road's flow level, print the result and save
     * it as Hive table result.areaTop3Road.
     *
     * @param spark active SparkSession; requires temp view tmp_area_road_flow_count
     *              (created by {@link #generateTempAreaRoadFlowTable}) and a Hive
     *              database named "result"
     */
    public static void getAreaTop3RoadFolwRDD(SparkSession spark) {
        // row_number() partitioned by area gives each road a per-area rank;
        // the outer query keeps rank <= 3 and derives a CASE-based flow level.
        String sql = ""
                + "SELECT "
                + "area_name,"
                + "road_id,"
                + "car_count,"
                + "monitor_infos, "
                + "CASE "
                + "WHEN car_count > 170 THEN 'A LEVEL' "
                + "WHEN car_count > 160 AND car_count <= 170 THEN 'B LEVEL' "
                + "WHEN car_count > 150 AND car_count <= 160 THEN 'C LEVEL' "
                + "ELSE 'D LEVEL' "
                + "END flow_level "
                + "FROM ("
                + "SELECT "
                + "area_name,"
                + "road_id,"
                + "car_count,"
                + "monitor_infos,"
                + "row_number() OVER (PARTITION BY area_name ORDER BY car_count DESC) rn "
                + "FROM tmp_area_road_flow_count "
                + ") tmp "
                + "WHERE rn <=3";
        Dataset<Row> result = spark.sql(sql);
        System.out.println("--------最终的结果-------");
        result.show();
        // Persist into Hive; the "result" database must already exist.
        spark.sql("use result");
        spark.sql("drop table if exists result.areaTop3Road");
        result.write().mode(SaveMode.Overwrite).saveAsTable("areaTop3Road");
    }


    /**
     * Step 4: count cars per (area_name, road_id) and collect the distinct
     * monitor ids per group via the group_concat_distinct UDAF, then register
     * the result as temp view tmp_area_road_flow_count.
     *
     * @param spark active SparkSession; requires temp view tmp_car_flow_basic
     *              (created by {@link #generateTempRoadFlowBasicTable})
     */
    public static void generateTempAreaRoadFlowTable(SparkSession spark) {
        String sql =
                "SELECT "
                        + "area_name,"
                        + "road_id,"
                        + "count(*) car_count,"
                        + "group_concat_distinct(monitor_id) monitor_infos "
                        + "FROM tmp_car_flow_basic "
                        + "GROUP BY area_name,road_id";
        Dataset<Row> ds = spark.sql(sql);
        // createOrReplaceTempView replaces the deprecated registerTempTable
        // and matches the style used in generateTempRoadFlowBasicTable.
        ds.createOrReplaceTempView("tmp_area_road_flow_count");
    }

    /**
     * Step 3: join the vehicle detail rows with the area dimension data to
     * attach the area name, and register the flattened rows
     * (area_id, area_name, road_id, monitor_id, car) as temp view
     * tmp_car_flow_basic.
     *
     * @param spark              active SparkSession
     * @param areaId2DetailInfos (areaId, detail Row) pairs from {@link #getInfosByDateRDD}
     * @param areaId2AreaInfoRDD (areaId, areaName) pairs from {@link #getAreaId2AreaInfoRDD}
     */
    private static void generateTempRoadFlowBasicTable(SparkSession spark,
                                                       JavaPairRDD<String, Row> areaId2DetailInfos, JavaPairRDD<String, String> areaId2AreaInfoRDD) {


        // Inner join: detail rows whose area_id has no match in area_info are dropped.
        JavaRDD<Row> tmpRowRDD = areaId2DetailInfos.join(areaId2AreaInfoRDD).map(
                new Function<Tuple2<String, Tuple2<Row, String>>, Row>() {

                    private static final long serialVersionUID = 1L;

                    @Override
                    public Row call(Tuple2<String, Tuple2<Row, String>> tuple) throws Exception {
                        String areaId = tuple._1;
                        Row carFlowDetailRow = tuple._2._1;
                        String areaName = tuple._2._2;

                        String roadId = carFlowDetailRow.getAs("road_id");
                        String monitorId = carFlowDetailRow.getAs("monitor_id");
                        String car = carFlowDetailRow.getAs("car");


                        return RowFactory.create(areaId, areaName, roadId, monitorId, car);
                    }
                });

        // Schema must list the fields in the same order as RowFactory.create above.
        List<StructField> structFields = new ArrayList<>();
        structFields.add(DataTypes.createStructField("area_id", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("area_name", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("road_id", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("monitor_id", DataTypes.StringType, true));
        structFields.add(DataTypes.createStructField("car", DataTypes.StringType, true));

        StructType schema = DataTypes.createStructType(structFields);

        Dataset<Row> df = spark.createDataFrame(tmpRowRDD, schema);

        df.createOrReplaceTempView("tmp_car_flow_basic");

    }

    /**
     * Step 2: load the area_info dimension table from MySQL via the Spark JDBC
     * data source and return it keyed by area id.
     *
     * @param spark active SparkSession
     * @return (areaId, areaName) pairs; column 0 of area_info is assumed to be
     *         the id and column 1 the name — TODO confirm against the table schema
     */
    private static JavaPairRDD<String, String> getAreaId2AreaInfoRDD(SparkSession spark) {
        String url = null;
        String user = null;
        String password = null;

        boolean local = ConfigurationManager.getBoolean(Constants.SPARK_LOCAL);
        // Pick the dev or prod MySQL credentials depending on the run mode.
        if (local) {
            url = ConfigurationManager.getProperty(Constants.JDBC_URL);
            user = ConfigurationManager.getProperty(Constants.JDBC_USER);
            password = ConfigurationManager.getProperty(Constants.JDBC_PASSWORD);
        } else {
            url = ConfigurationManager.getProperty(Constants.JDBC_URL_PROD);
            user = ConfigurationManager.getProperty(Constants.JDBC_USER_PROD);
            password = ConfigurationManager.getProperty(Constants.JDBC_PASSWORD_PROD);
        }

        Map<String, String> options = new HashMap<>();
        options.put("url", url);
        options.put("driver", "com.mysql.jdbc.Driver");
        options.put("user", user);
        options.put("password", password);
        options.put("dbtable", "area_info");

        // Read the whole area_info table through the JDBC data source.
        Dataset<Row> areaInfoDF = spark.read().format("jdbc").options(options).load();
        System.out.println("------------Mysql数据库中的表area_info数据为------------");
        areaInfoDF.show();
        // Convert to an RDD keyed by area id.
        JavaRDD<Row> areaInfoRDD = areaInfoDF.javaRDD();

        JavaPairRDD<String, String> areaid2areaInfoRDD = areaInfoRDD.mapToPair(
                new PairFunction<Row, String, String>() {

                    private static final long serialVersionUID = 1L;

                    @Override
                    public Tuple2<String, String> call(Row row) throws Exception {
                        String areaid = String.valueOf(row.get(0));
                        String areaname = String.valueOf(row.get(1));
                        return new Tuple2<>(areaid, areaname);
                    }
                });

        return areaid2areaInfoRDD;
    }


    /**
     * Step 1: query monitor_flow_action for rows within [startDate, endDate]
     * and return them keyed by area id.
     *
     * NOTE(review): the dates come from task parameters and are concatenated
     * into the SQL string; trusted here since they originate from our own task
     * table, but worth validating upstream.
     *
     * @param spark     active SparkSession
     * @param taskParam task parameters containing start/end date
     * @return (areaId, detail Row) pairs with columns monitor_id, car, road_id, area_id
     */
    private static JavaPairRDD<String, Row> getInfosByDateRDD(SparkSession spark, JSONObject taskParam) {
        String startDate = ParamUtils.getParam(taskParam, Constants.PARAM_START_DATE);
        String endDate = ParamUtils.getParam(taskParam, Constants.PARAM_END_DATE);

        // BUGFIX: the original concatenation produced "...'<startDate>'AND date..."
        // with no space before AND, which is malformed SQL.
        String sql = "SELECT "
                + "monitor_id,"
                + "car,"
                + "road_id,"
                + "area_id "
                + "FROM	traffic.monitor_flow_action "
                + "WHERE date >= '" + startDate + "' "
                + "AND date <= '" + endDate + "'";

        Dataset<Row> df = spark.sql(sql);
        return df.javaRDD().mapToPair(new PairFunction<Row, String, Row>() {

            private static final long serialVersionUID = 1L;

            @Override
            public Tuple2<String, Row> call(Row row) throws Exception {
                String areaId = row.getAs("area_id");
                return new Tuple2<>(areaId, row);
            }
        });
    }

}
